{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a44a9073",
   "metadata": {},
   "source": [
    "## Bayesian Networks"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "14ce3241",
   "metadata": {},
   "source": [
    "author: Jacob Schreiber <br>\n",
    "contact: jmschreiber91@gmail.com\n",
    "\n",
    "Bayesian networks are general-purpose probabilistic models that are a superset of all the other models presented in pomegranate. Specifically, a Bayesian network factorizes a joint probability distribution along a graph structure, where the presence of an edge represents a directed dependency between two variables and the absence of an edge represents a conditional independence. Importantly, the absence of an edge does not by itself imply that two variables are independent, only that they are conditionally independent given some set of other variables."
   ]
  },
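  {
   "cell_type": "markdown",
   "id": "c1a0f001",
   "metadata": {},
   "source": [
    "To make the difference between independence and conditional independence concrete, consider a three-variable chain $A \\rightarrow B \\rightarrow C$. This is a standard textbook example, not a model used elsewhere in this notebook. Its joint distribution factorizes as\n",
    "\n",
    "$$P(A, B, C) = P(A)\\,P(B \\mid A)\\,P(C \\mid B).$$\n",
    "\n",
    "There is no edge between $A$ and $C$, and accordingly $P(C \\mid A, B) = P(C \\mid B)$: the two are conditionally independent given $B$. They are generally not marginally independent, because\n",
    "\n",
    "$$P(C \\mid A) = \\sum_{b} P(C \\mid B = b)\\,P(B = b \\mid A)$$\n",
    "\n",
    "still depends on $A$ through $B$."
   ]
  },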
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1b6e4ad5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Populating the interactive namespace from numpy and matplotlib\n",
      "torch      : 1.13.0\n",
      "pomegranate: 1.0.0\n",
      "\n",
      "Compiler    : GCC 11.2.0\n",
      "OS          : Linux\n",
      "Release     : 4.15.0-208-generic\n",
      "Machine     : x86_64\n",
      "Processor   : x86_64\n",
      "CPU cores   : 8\n",
      "Architecture: 64bit\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%pylab inline\n",
    "import seaborn; seaborn.set_style('whitegrid')\n",
    "import torch\n",
    "\n",
    "%load_ext watermark\n",
    "%watermark -m -n -p torch,pomegranate"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc74c689",
   "metadata": {},
   "source": [
    "### Initialization and Fitting\n",
    "\n",
    "Similar to the hidden Markov model, the Bayesian network is composed of a set of distributions and a graph structure connecting them. In this case, the graph is just a set of directed, unweighted edges. Most Bayesian networks require that this graph be acyclic. However, because pomegranate uses a factor graph to do inference, there is no strict requirement that this is the case. See the Prediction section below.\n",
    "\n",
    "Likewise, similar to the other models in pomegranate, a Bayesian network can be learned in its entirety from data. However, exact structure learning is intractable and so the field has developed a variety of approximations. See the Bayesian network structure learning tutorial for more.\n",
    "\n",
    "For now, let's implement the simplest Bayesian network: a child and a parent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ee0487b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pomegranate.distributions import Categorical\n",
    "from pomegranate.distributions import ConditionalCategorical\n",
    "from pomegranate.bayesian_network import BayesianNetwork\n",
    "\n",
    "d1 = Categorical([[0.1, 0.9]])\n",
    "d2 = ConditionalCategorical([[[0.4, 0.6], [0.3, 0.7]]])\n",
    "\n",
    "model = BayesianNetwork([d1, d2], [(d1, d2)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51b582aa",
   "metadata": {},
   "source": [
    "Alternatively, one can use the `add_distributions` and `add_edge` methods to create the network programmatically."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "39494b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "model2 = BayesianNetwork()\n",
    "model2.add_distributions([d1, d2])\n",
    "model2.add_edge(d1, d2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0ad8a0c",
   "metadata": {},
   "source": [
    "Once these models are initialized with a structure, they can be fit to data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "94cc3184",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[0, 0],\n",
       "       [0, 0],\n",
       "       [0, 1],\n",
       "       [0, 0],\n",
       "       [0, 1],\n",
       "       [0, 1],\n",
       "       [0, 0],\n",
       "       [1, 1],\n",
       "       [1, 0],\n",
       "       [0, 1]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X = numpy.random.randint(2, size=(10, 2))\n",
    "X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "2d0bba9b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "BayesianNetwork(\n",
       "  (distributions): ModuleList(\n",
       "    (0): Categorical()\n",
       "    (1): ConditionalCategorical(\n",
       "      (probs): ParameterList(  (0): Parameter containing: [torch.float32 of size 2x2])\n",
       "      (_w_sum): [tensor([0., 0.])]\n",
       "      (_xw_sum): [tensor([[0., 0.],\n",
       "              [0., 0.]])]\n",
       "      (_log_probs): [tensor([[-0.6931, -0.6931],\n",
       "              [-0.6931, -0.6931]])]\n",
       "    )\n",
       "  )\n",
       "  (_factor_graph): FactorGraph(\n",
       "    (factors): ModuleList(\n",
       "      (0): Categorical()\n",
       "      (1): JointCategorical()\n",
       "    )\n",
       "    (marginals): ModuleList(\n",
       "      (0): Categorical()\n",
       "      (1): Categorical()\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(X)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "328bcb71",
   "metadata": {},
   "source": [
    "Next, we can check that the learned parameters are correct."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4a4d34bd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([[0.8000, 0.2000]]),\n",
       " 0.2)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[0].probs, X[:,0].mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8cc48140",
   "metadata": {},
   "source": [
    "Looks like the model learned a marginal distribution where the probability of 1 is equal to the mean value, which is correct.\n",
    "\n",
    "We can also look at the conditional probability table and compare against the probability that the second column is 1 (by taking the mean) when the first column is 0."
   ]
  },
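  {
   "cell_type": "markdown",
   "id": "c1a0f002",
   "metadata": {},
   "source": [
    "As a sanity check on what `fit` estimates, the maximum likelihood estimate of each table entry is a ratio of counts: $\\hat{P}(X_1 = b \\mid X_0 = a) = N(a, b) / N(a)$. The cell below is a minimal sketch of that calculation in plain Python, not pomegranate's implementation. It hard-codes the ten rows of `X` from the saved output above (a fresh run will draw different random data), and the helper name `conditional_table` is hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0f003",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "# The ten rows of X from the saved output above (columns: parent X0, child X1).\n",
    "rows = [(0, 0), (0, 0), (0, 1), (0, 0), (0, 1),\n",
    "        (0, 1), (0, 0), (1, 1), (1, 0), (0, 1)]\n",
    "\n",
    "def conditional_table(rows, n_categories=2):\n",
    "    \"\"\"Estimate P(child | parent) by counting; assumes every parent value occurs.\"\"\"\n",
    "    pair_counts = Counter(rows)                  # N(a, b)\n",
    "    parent_counts = Counter(a for a, _ in rows)  # N(a)\n",
    "    return [[pair_counts[(a, b)] / parent_counts[a] for b in range(n_categories)]\n",
    "            for a in range(n_categories)]\n",
    "\n",
    "# Marginal of the parent: fraction of rows taking each value.\n",
    "marginal = [sum(1 for a, _ in rows if a == v) / len(rows) for v in (0, 1)]\n",
    "table = conditional_table(rows)\n",
    "\n",
    "print(marginal)  # [0.8, 0.2]\n",
    "print(table)     # [[0.5, 0.5], [0.5, 0.5]]"
   ]
  },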
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4f32e05e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([[0.5000, 0.5000],\n",
       "         [0.5000, 0.5000]]),\n",
       " 0.5)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[1].probs[0], (X[X[:,0] == 0][:,1]).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b4c1d7c",
   "metadata": {},
   "source": [
    "Also looks correct.\n",
    "\n",
    "Finally, if we know the structure of the Bayesian network that we'd like to learn but don't know the parameters of the tables, we can pass the structure into the initialization and call the fit function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "72ed676e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.5000, 0.5000],\n",
       "        [0.5000, 0.5000]])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model3 = BayesianNetwork(structure=[(), (0,)])\n",
    "model3.fit(X)\n",
    "\n",
    "model3.distributions[1].probs[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e9fae78",
   "metadata": {},
   "source": [
    "### Probabilities and Log Probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67163cab",
   "metadata": {},
   "source": [
    "Much like other generative models, Bayesian networks can calculate the probabilities of examples given the distributions and graph structure. Because the Bayesian network is just a factorization of the joint probability table along the graph structure, the probability of an example is just the product of the probability of each variable given its parents."
   ]
  },
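  {
   "cell_type": "markdown",
   "id": "c1a0f004",
   "metadata": {},
   "source": [
    "Written out, a network over variables $X_0, \\ldots, X_{n-1}$ assigns an example the probability\n",
    "\n",
    "$$P(x_0, \\ldots, x_{n-1}) = \\prod_{i=0}^{n-1} P\\big(x_i \\mid \\mathrm{pa}(x_i)\\big),$$\n",
    "\n",
    "where $\\mathrm{pa}(x_i)$ denotes the values taken by the parents of variable $i$ (notation introduced here for illustration). For the two-variable model fit above, with the parameters shown in the saved outputs, the example $(0, 0)$ has probability $P(X_0 = 0)\\,P(X_1 = 0 \\mid X_0 = 0) = 0.8 \\times 0.5 = 0.4$, and its log probability is $\\log 0.8 + \\log 0.5 \\approx -0.2231 - 0.6931 = -0.9163$. Likewise, the example $(1, 1)$ has probability $0.2 \\times 0.5 = 0.1$."
   ]
  },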
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d122a19c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.4000, 0.1000, 0.1000,\n",
       "        0.4000])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.probability(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5c38c8fa",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -2.3026,\n",
       "        -2.3026, -0.9163])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.log_probability(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c9ff8acb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -0.9163, -2.3026,\n",
       "        -2.3026, -0.9163])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
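   ],
   "source": [
    "# Summing each variable's log probability given its parents reproduces model.log_probability(X).\n",
    "model.distributions[0].log_probability(X[:,:1]) + model.distributions[1].log_probability(X[:, :, None])"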
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40e7863d",
   "metadata": {},
   "source": [
    "As with other models in pomegranate, the inputs can be lists, numpy arrays, or torch tensors."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ed1ed550",
   "metadata": {},
   "source": [
    "### Prediction\n",
    "\n",
    "Perhaps the most useful application of a learned Bayesian network is inferring missing values. Rather than a traditional prediction problem, which has a fixed set of inputs and one or more fixed outputs, Bayesian network inference uses whichever variables are observed to infer whichever variables are not. The set of observed variables can change from example to example, and so does not need to be fixed in advance.\n",
    "\n",
    "In pomegranate, this is done using the loopy belief propagation algorithm, sometimes also called the \"sum-product\" algorithm. This algorithm is run on a factor graph, which is constructed in the backend. Compared with standard junction-tree inference, the algorithm is faster, easier to implement, exact for tree-like Bayesian networks, and able to provide estimates even for cyclic networks. The trade-off is that inference is not guaranteed to be exact in other cases, or even to converge when the network is cyclic.\n",
    "\n",
    "The implementation of the prediction methods differs slightly from other models in pomegranate. The unobserved variables are indicated using a `torch.masked.MaskedTensor` object, which holds the underlying data and a mask where `True` means the value is observed and `False` means that it is not observed. When the mask is `False`, it does not matter what the underlying value is."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "09d83981",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jmschr/anaconda3/lib/python3.9/site-packages/torch/masked/maskedtensor/core.py:156: UserWarning: The PyTorch API of MaskedTensors is in prototype stage and will change in the near future. Please open a Github issue for features requests and see our documentation on the torch.masked module for further information about the project.\n",
      "  warnings.warn((\"The PyTorch API of MaskedTensors is in prototype stage \"\n"
     ]
    }
   ],
   "source": [
    "X_torch = torch.tensor(X[:4])\n",
    "mask = torch.tensor([[True, False],\n",
    "                     [False, True],\n",
    "                     [True, True],\n",
    "                     [False, False]])\n",
    "\n",
    "X_masked = torch.masked.MaskedTensor(X_torch, mask=mask)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f31835e",
   "metadata": {},
   "source": [
    "If you have already marked missing values in your tensor with a sentinel such as -1, you can build the masked tensor directly with `torch.masked.MaskedTensor(X, mask=(X != -1))`.\n",
    "\n",
    "Once you've created your `MaskedTensor`, you can pass it into the `predict`, `predict_proba`, or `predict_log_proba` methods of your Bayesian network. Note that observed values appear in the output with a probability of 1. Non-trivial probability distributions are only returned for unobserved variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3f9a8d54",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[1.0000, 0.0000],\n",
       "         [0.8000, 0.2000],\n",
       "         [1.0000, 0.0000],\n",
       "         [0.8000, 0.2000]]),\n",
       " tensor([[0.5000, 0.5000],\n",
       "         [1.0000, 0.0000],\n",
       "         [0.0000, 1.0000],\n",
       "         [0.5000, 0.5000]])]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_proba(X_masked)"
   ]
  },
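  {
   "cell_type": "markdown",
   "id": "c1a0f005",
   "metadata": {},
   "source": [
    "For a network this small, the output can be checked by hand with Bayes' rule. In the second example, $X_1 = 0$ is observed and $X_0$ is missing, so\n",
    "\n",
    "$$P(X_0 = a \\mid X_1 = 0) = \\frac{P(X_0 = a)\\,P(X_1 = 0 \\mid X_0 = a)}{\\sum_{a'} P(X_0 = a')\\,P(X_1 = 0 \\mid X_0 = a')}.$$\n",
    "\n",
    "Plugging in the fitted parameters from the saved outputs gives unnormalized values of $0.8 \\times 0.5 = 0.4$ for $a = 0$ and $0.2 \\times 0.5 = 0.1$ for $a = 1$, which normalize to $0.8$ and $0.2$. This matches the second row of the first tensor above."
   ]
  },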
  {
   "cell_type": "markdown",
   "id": "69c03439",
   "metadata": {},
   "source": [
    "You might notice that the output from these functions has a different shape than the output of other methods. Because there is no guarantee that the variables all have the same number of categories, pomegranate cannot return a single tensor where one of the dimensions is the number of categories. Instead, pomegranate returns a list of tensors, where each element in the list corresponds to one variable and has shape `(n_examples, n_categories)`, with `n_categories` being the number of categories for that variable. In principle, one could return a single tensor of size `(n_examples, n_dimensions, max_n_categories)`, where `max_n_categories` is the maximum number of categories across all dimensions. However, one would likely slice the unused categories out anyway, and a single variable with a large number of categories could massively increase the amount of memory needed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e214351f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([[ 0.0000,    -inf],\n",
       "         [-0.2231, -1.6094],\n",
       "         [ 0.0000,    -inf],\n",
       "         [-0.2231, -1.6094]]),\n",
       " tensor([[-0.6931, -0.6931],\n",
       "         [ 0.0000,    -inf],\n",
       "         [   -inf,  0.0000],\n",
       "         [-0.6931, -0.6931]])]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict_log_proba(X_masked)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e0d0d976",
   "metadata": {},
   "source": [
    "Unlike the previous two methods, the `predict` method will preserve the shape of the original data but replace the missing values, according to the mask, with the maximally likely value from the `predict_proba` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "9eccf44f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0, 0],\n",
       "        [0, 0],\n",
       "        [0, 1],\n",
       "        [0, 0]])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.predict(X_masked)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87ecf843",
   "metadata": {},
   "source": [
    "### Summarization"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ba843259",
   "metadata": {},
   "source": [
    "Summarization for Bayesian networks works the same way as summarization for other models. When given complete data, the sufficient statistics needed for maximum likelihood estimation are computed for each batch and can be added together across batches. This means that Bayesian networks can be learned in an out-of-core manner.\n",
    "\n",
    "Importantly, the Bayesian network must already have a structure or it will not know what statistics to calculate per-batch. This is because structure learning is not implemented in an out-of-core manner and would otherwise have to be done on the first batch. If this is what you want to do, then you should use `fit` on the first batch and then `summarize` on successive batches."
   ]
  },
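  {
   "cell_type": "markdown",
   "id": "c1a0f006",
   "metadata": {},
   "source": [
    "The reason batches can be processed independently is that the statistics involved are counts, and counts over disjoint batches add up to the counts over the full data set. The cell below is a minimal sketch of that idea in plain Python. It does not call pomegranate, and both the data and the helper name `count_pairs` are made up for illustration. It mirrors the `summarize` / `from_summaries` pattern used in the next cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1a0f007",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "def count_pairs(batch):\n",
    "    \"\"\"Count how often each (parent, child) value pair occurs in a batch.\"\"\"\n",
    "    return Counter(batch)\n",
    "\n",
    "# Hypothetical (parent, child) observations, split into two batches.\n",
    "data = [(0, 0), (0, 1), (1, 1), (0, 1), (1, 0), (1, 1), (0, 0), (1, 1)]\n",
    "batch1, batch2 = data[:3], data[3:]\n",
    "\n",
    "# Summarize each batch separately, then add the statistics together.\n",
    "running = count_pairs(batch1) + count_pairs(batch2)\n",
    "\n",
    "# The accumulated counts equal the counts from a single pass over all data.\n",
    "assert running == count_pairs(data)\n",
    "\n",
    "# Turn the accumulated counts into P(child = 1 | parent) for each parent value.\n",
    "p_one = {}\n",
    "for parent in (0, 1):\n",
    "    total = running[(parent, 0)] + running[(parent, 1)]\n",
    "    p_one[parent] = running[(parent, 1)] / total\n",
    "\n",
    "print(p_one)  # {0: 0.5, 1: 0.75}"
   ]
  },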
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "274a11f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "X = torch.randint(2, size=(20, 4))\n",
    "\n",
    "d1 = Categorical([[0.1, 0.9]])\n",
    "d2 = Categorical([[0.3, 0.7]])\n",
    "d3 = ConditionalCategorical([[[0.4, 0.6], [0.3, 0.7]]])\n",
    "d4 = ConditionalCategorical([[[[0.8, 0.2], [0.1, 0.9]], [[0.1, 0.9], [0.5, 0.5]]]])\n",
    "\n",
    "model = BayesianNetwork([d1, d2, d3, d4], [(d1, d3), (d2, d4), (d3, d4)])\n",
    "model.summarize(X[:5])\n",
    "model.summarize(X[5:])\n",
    "model.from_summaries()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "781d5d33",
   "metadata": {},
   "source": [
    "After fitting, we can check what the learned parameters are."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "29901f53",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Parameter containing:\n",
       " tensor([[0.6000, 0.4000]]),\n",
       " Parameter containing:\n",
       " tensor([[0.5500, 0.4500]]))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[0].probs, model.distributions[1].probs"
   ]
  },
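  {
   "cell_type": "markdown",
   "id": "19f3af2b",
   "metadata": {},
   "source": [
    "And we can compare them to the actual averages from the data. Because the data are binary, the mean of each column is the empirical probability of a 1, so the first two means should match the second column of the learned parameters for the first two dimensions."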
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "32d18a0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4000, 0.4500, 0.4500, 0.6000])"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.mean(X.type(torch.float32), dim=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecc2b7f1",
   "metadata": {},
   "source": [
    "### Sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3068aead",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_sample = model.sample(1000000).type(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "b9df335f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[(), (), (0,), (1, 2)]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model._parents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3b020e0f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.6000, 0.4000]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[0].probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2285d738",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.5500, 0.4500]])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[1].probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "77bbc355",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.4006, 0.4501])"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_sample[:, :2].mean(dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1384cfc0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[0.4167, 0.5833],\n",
       "        [0.7500, 0.2500]])"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[2].probs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "2bb22d29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(0.5836), tensor(0.2500))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_sample[X_sample[:, 0] == 0, 2].mean(), X_sample[X_sample[:, 0] == 1, 2].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "93cca0b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Parameter containing:\n",
       "tensor([[[0.5000, 0.5000],\n",
       "         [0.2000, 0.8000]],\n",
       "\n",
       "        [[0.4000, 0.6000],\n",
       "         [0.5000, 0.5000]]])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.distributions[3].probs[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "4eaa1041",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.5998)"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_sample[(X_sample[:, 1] == 1) & (X_sample[:, 2] == 0), 3].mean()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
